#!/usr/bin/env python
# Software License Agreement (Lesser GPL)
#
# Copyright (C) 2009-2010 Rosen Diankov
#
# ikfast is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ikfast is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
from __future__ import with_statement # for python 2.5

import sys, copy, time, datetime
# AutoReloader is an optional openravepy helper; when openravepy is missing or
# fails to load, fall back to an empty stand-in base class.
try:
    from openravepy.metaclass import AutoReloader
except Exception:
    # Deliberately broad (openravepy can fail in several ways while loading),
    # but no longer a bare except: KeyboardInterrupt/SystemExit pass through.
    class AutoReloader:
        pass

# import the correct iktypes from openravepy (if present)
try:
    from openravepy import IkParameterization
    IkType = IkParameterization.Type
except Exception:
    # openravepy is unavailable (or too old to expose IkParameterization.Type):
    # use a local table of ik type ids instead.  A bare except would also have
    # swallowed KeyboardInterrupt/SystemExit.
    class IkType:
        Transform6D=1
        Rotation3D=2
        Translation3D=3
        Direction3D=4
        Ray4D=5

from sympy import *

# izip and combinations are imported separately: a single
# "from itertools import izip, combinations" fails as a whole when only
# combinations is missing (python 2.5), which used to leave izip undefined.
try:
    from itertools import izip
except ImportError:
    izip = zip # interpreters without izip already have a lazy builtin zip

try:
    from itertools import combinations
except ImportError:
    def combinations(items,n):
        """Yield every n-element combination of items, in input order.

        Fallback for interpreters whose itertools lacks combinations.
        NOTE: unlike itertools.combinations, each result is a list (not a
        tuple); this is kept as-is for existing callers.
        """
        if n == 0:
            yield []
            return
        for i,item in enumerate(items):
            for rest in combinations(items[i+1:],n-1):
                yield [item]+rest

def customcse(exprs,symbols=None):
    """Run sympy's cse() and flatten nested powers in its replacement list.

    Returns (replacements, reduced_exprs) exactly like cse(), except that a
    replacement of the form symbol**b, where symbol itself stands for x**a
    (both exponents real), is rewritten directly as x**(a*b).
    """
    replacements,reduced_exprs = cse(exprs,symbols=symbols)

    def flatten(entry):
        # entry is one (symbol, expression) pair produced by cse()
        target,value = entry
        if not (value.is_Pow and value.exp.is_real and value.base.is_Symbol):
            return entry
        # expand the base symbol to see whether it is a power as well
        inner = value.base.subs(replacements)
        if inner.is_Pow and inner.exp.is_real:
            return (target,inner.base**(value.exp*inner.exp))
        return entry

    return [flatten(entry) for entry in replacements],reduced_exprs

class CodeGenerator(AutoReloader):
    """Generates VisualBasic code from an AST generated by IKFastSolver"""
    # NOTE(review): the list-valued attributes below live on the class, so
    # in-place mutation (e.g. self.forloops.append in _startforloop) is shared
    # by every instance until an instance rebinds the attribute -- confirm this
    # is intended.
    dictequations = [] # symbols already written (a list; reset by each chain generator)
    symbolgen = cse_main.numbered_symbols('x') # source of fresh symbols with prefix 'x'
    strprinter = printing.StrPrinter() # sympy string printer
    freevars = None # list of free variables in the solution
    freevardependencies = None # list of variables depending on the free variables
    ikreal = 'Double' # VB type name used for real-valued declarations
    vb6 = False # True selects VB6 output, False selects VB.NET output
    globalid = 0 # running counter used to build unique VB6 loop labels
    forloops = [] # stack of open loops, pushed by _startforloop, popped by _endforloop
    dimvariables = [] # names collected during tree generation, declared in one Dim by generateIKChainRay4D
    rayclassname = 'Ray' # name of the generated ray class (fields x,y,z,i,j,k)

    def _startforloop(self,counter,start,end):
        if self.vb6:
            self.globalid += 1
            self.forloops.append((counter,'looplabel%d'%self.globalid))
            return '%s = %s\nDo While %s < %s\n'%(counter,start,counter,end)
        else:
            self.forloops.append(counter)
            return 'For %s = %s To %s Then\n'%(counter,start,end-1)
    def _endforloop(self):
        if self.vb6:
            counter,label = self.forloops.pop()
            return '%s:\n%s = %s+1\nLoop\n'%(label,counter,counter)
        else:
            counter = self.forloops.pop()
            return 'Next\n'%(counter)
    def _continueforloop(self):
        if self.vb6:
            return 'GoTo %s\n'%(self.forloops[-1][1])
        else:
            return 'Continue For\n'

    def _absname(self):
        return 'Abs' if self.vb6 else 'Math.Abs'
    def _sinname(self):
        return 'Sin' if self.vb6 else 'Math.Sin'
    def _cosname(self):
        return 'Cos' if self.vb6 else 'Math.Cos'

    def generate(self, solvertree, kinematicshash=''):
        print 'generating VisualBasic code...'
        code = """' autogenerated analytical inverse kinematics code from ikfast program
' \\author Rosen Diankov
'
' Licensed under the Apache License, Version 2.0 (the "License");
' you may not use this file except in compliance with the License.
' You may obtain a copy of the License at
'     http://www.apache.org/licenses/LICENSE-2.0
' 
' Unless required by applicable law or agreed to in writing, software
' distributed under the License is distributed on an "AS IS" BASIS,
' WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
' See the License for the specific language governing permissions and
' limitations under the License.
'
' generated %s
"""%(str(datetime.datetime.now()))
        if self.vb6:
            code += """
'' put in IKBaseSolution.cls file
'VERSION 1.0 CLASS
'BEGIN
'  MultiUse = -1  'True
'  Persistable = 0  'NotPersistable
'  DataBindingBehavior = 0  'vbNone
'  DataSourceBehavior  = 0  'vbNone
'  MTSTransactionMode  = 0  'NotAnMTSObject
'END
'Attribute VB_Name = "IKBaseSolution"
'Attribute VB_GlobalNameSpace = False
'Attribute VB_Creatable = True
'Attribute VB_PredeclaredId = False
'Attribute VB_Exposed = False
'Public fmul, foffset As %s
'Public freeind As Integer

'' put in IKSolution.cls file
'VERSION 1.0 CLASS
'BEGIN
'  MultiUse = -1  'True
'  Persistable = 0  'NotPersistable
'  DataBindingBehavior = 0  'vbNone
'  DataSourceBehavior  = 0  'vbNone
'  MTSTransactionMode  = 0  'NotAnMTSObject
'END
'Attribute VB_Name = "IKSolution"
'Attribute VB_GlobalNameSpace = False
'Attribute VB_Creatable = True
'Attribute VB_PredeclaredId = False
'Attribute VB_Exposed = False
'Public basesol() As IKBaseSolution
'Public vfree() As Integer

Const IK2PI As %s =  6.28318530717959
Const IKPI As %s = 3.14159265358979
Const IKPI_2 As %s =  1.57079632679490

"""%(self.ikreal,self.ikreal,self.ikreal,self.ikreal)
            code += """
Public Function IKasin(ByVal f As %s) As %s
    If f <= -1 Then
        IKasin = -IKPI_2
    ElseIf f >= 1 Then
        IKasin = IKPI_2
    Else
        IKasin = Atn(f / Sqr(1 - f * f))
    End If
End Function

Public Function IKacos(ByVal f As %s) As %s
    If f <= -1 Then
        IKacos = IKPI
    ElseIf f >= 1 Then
        IKacos = 0
    Else
        IKacos = IKPI_2 - Atn(f / Sqr(1 - f * f))
    End If
End Function

Public Function IKatan2(ByVal y As %s, ByVal x As %s) As %s
    If Not IsNumeric(y) Then
        IKatan2 = IKPI_2
    ElseIf  Not IsNumeric(x) Then
        IKatan2 = 0
    ElseIf x = 0 And y = 0 Then
        IKatan2 = 0
    Else
        If y > 0 Then
            If x >= y Then
                IKatan2 = Atn(y / x)
            ElseIf x <= -y Then
                IKatan2 = Atn(y / x) + Pi
            Else
                IKatan2 = Pi / 2 - Atn(x / y)
            End If
        Else
            If x >= -y Then
                IKatan2 = Atn(y / x)
            ElseIf x <= y Then
                IKatan2 = Atn(y / x) - Pi
            Else
                IKatan2 = -Atn(x / y) - Pi / 2
            End If
        End If
    End If
End Function

Public Function IKsqrt(ByVal f As %s) As %s
    If f <= 0.0 Then
        IKsqrt = 0
    Else
        IKsqrt = Sqr(f)
    End If
End Function

Public Function IKdiv(ByVal f As %s) As %s
    If Abs(f) <= 0.0 Then
        IKdiv = 1.0e30
    Else
        IKdiv = 1.0/f
    End If
End Function

' assumes exp < 0
Public Function IKnegpow(ByVal f As %s, ByVal exp As %s) As %s
    If Abs(f) <= 0.0 Then
        IKnegpow = 1.0e30
    Else
        IKnegpow = f^exp
    End If
End Function
"""%(self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal)
        else: # vb.net
            code += """
Imports System

Const IK2PI As %s =  6.28318530717959
Const IKPI As %s = 3.14159265358979
Const IKPI_2 As %s =  1.57079632679490

Public Class IKBaseSolution
    Public fmul, foffset As %s
    Public freeind As Integer
End Class

Public Class IKSolution
    Public basesol() As IKBaseSolution
    Public vfree() As Integer  
End Class
"""%(self.ikreal,self.ikreal,self.ikreal,self.ikreal)
            code += """
Public Function IKasin(ByVal f As %s) As %s
    If f <= -1 Then
        IKasin = -IKPI_2
    ElseIf f >= 1 Then
        IKasin = IKPI_2
    Else
        IKasin = Math.Asin(f)
    End If
End Function

Public Function IKacos(f As %s) As %s
    If f <= -1 Then
        IKacos = IKPI
    ElseIf f >= 1 Then
        IKacos = 0
    Else
        IKacos = Math.Acos(f)
    End If
End Function

Public Function IKatan2(y As %s, x As %s) As %s
    If Double.IsNaN(y) Then
        IKatan2 = IKPI_2
    ElseIf  Double.IsNaN(x) Then
        IKatan2 = 0
    Else
        IKatan2 = Math.Atan2(y,x)
    End If
End Function

Public Function IKsqrt(f As %s) As %s
    If f <= 0.0 Then
        IKsqrt = 0
    Else
        IKsqrt = Math.Sqrt(f)
    End If
End Function

Public Function IKdiv(f As %s) As %s
    If Math.Abs(f) <= 0.0 Then
        IKdiv = 1.0e30
    Else
        IKdiv = 1.0/f
    End If
End Function

' assumes exp < 0
Public Function IKnegpow(f As %s, exp As %s) As %s
    If Math.Abs(f) <= 0.0 Then
        IKnegpow = 1.0e30
    Else
        IKnegpow = Math.Pow(f,exp)
    End If
End Function
"""%(self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal)
        code += solvertree.generate(self)
        code += solvertree.end(self)
        return code

    def generateChain(self, node):
        """Emit the solver for a full-transform (IkType.Transform6D) chain -- currently disabled.

        The leading ``assert False`` aborts every call while assertions are
        enabled.  The body is a partially ported draft: the accessor functions,
        the Matrix4x4 class and the fk() header are VisualBasic, while the ik()
        wrapper text is still C++.

        node -- chain node; the code reads freejointvars, solvejointvars, Tfk,
                Tee and jointtree from it.
        Returns the generated source text as one string.
        """
        assert False
        # reset the per-chain bookkeeping
        self.freevars = []
        self.freevardependencies = []
        self.dictequations = []
        self.symbolgen = cse_main.numbered_symbols('x')
        
        code = "Public Function getNumFreeParameters() As Integer\n    getNumFreeParameters = %d\nEnd Function\n"%len(node.freejointvars)
        # C++ GetFreeIndices() accessor, not ported to VisualBasic:
#         if len(node.freejointvars) == 0:
#             code += "IKFAST_API const int* GetFreeIndices() { return NULL; }\n"
#         else:
#             code += "IKFAST_API const int* GetFreeIndices() { static const int freeindices[] = {"
#             for i,freejointvar in enumerate(node.freejointvars):
#                 code += "%d"%(freejointvar[1])
#                 if i < len(node.freejointvars)-1:
#                     code += ", "
#             code += "}; return freeindices; }\n"
        code += 'Public Function getNumJoints() As Integer\n    getNumJoints = %d\nEnd Function\n'%(len(node.freejointvars)+len(node.solvejointvars))
        #code += "Public Function getIKRealSize() As Integer\n    return sizeof(IKReal); }\n\n"
        code += 'Public Function getIKType() As Integer\n    getIKType = %d\nEnd Function\n'%IkType.Transform6D
        code += """
Public Class Matrix4x4
    Public m00 As %s
    Public m01 As %s
    Public m02 As %s
    Public m03 As %s
    Public m10 As %s
    Public m11 As %s
    Public m12 As %s
    Public m13 As %s
    Public m20 As %s
    Public m21 As %s
    Public m22 As %s
    Public m23 As %s
    Public m30 As %s
    Public m31 As %s
    Public m32 As %s
    Public m33 As %s
End Class
"""%(self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal)
        # generate the fk
        
        if node.Tfk:
            code += "' solves the forward kinematics equations.\n"
            code += "' pfree is an array specifying the free joints of the chain.\n"
            code += "Public Function fk(j() As %s) As Matrix4x4\n"%(self.ikreal)
            allvars = node.solvejointvars + node.freejointvars
            # replace every joint symbol by j[index], then factor out common subexpressions
            subexprs,reduced_exprs=customcse (node.Tfk[0:3,0:4].subs([(v[0],Symbol('j[%d]'%v[1])) for v in allvars]))
            # row-major targets for the 3x4 block: three rotation entries, then one translation entry per row
            outputnames = ['eerot[0]','eerot[1]','eerot[2]','eetrans[0]','eerot[3]','eerot[4]','eerot[5]','eetrans[1]','eerot[6]','eerot[7]','eerot[8]','eetrans[2]']
            fcode = ''
            if len(subexprs) > 0:
                vars = [var for var,expr in subexprs]
                fcode = 'Dim ' + ','.join(str(var) for var,expr in subexprs) + ' As %s\n'%self.ikreal
                for var,expr in subexprs:
                    fcode += self.writeEquations(lambda k: str(var),collect(expr,vars))
            for i in range(len(outputnames)):
                fcode += self.writeEquations(lambda k: outputnames[i],reduced_exprs[i])
            code += self.indentCode(fcode,4)
            # NOTE(review): closes the VisualBasic function with a C++ brace
            code += '}\n\n'
        # from here on the emitted text is still C++
        code += "/// solves the inverse kinematics equations.\n"
        code += "/// \\param pfree is an array specifying the free joints of the chain.\n"
        code += "IKFAST_API bool ik(const IKReal* eetrans, const IKReal* eerot, const IKReal* pfree, std::vector<IKSolution>& vsolutions) {\n"
        code += self._startforloop('dummyiter',0,1)
        fcode = "vsolutions.resize(0); vsolutions.reserve(8);\n"
        fcode += 'IKReal '
        
        # one declaration list: angle, cosine and sine per joint
        for var in node.solvejointvars:
            fcode += '%s, c%s, s%s,\n'%(var[0].name,var[0].name,var[0].name)
        for i in range(len(node.freejointvars)):
            name = node.freejointvars[i][0].name
            fcode += '%s=pfree[%d], c%s=cos(pfree[%d]), s%s=sin(pfree[%d]),\n'%(name,i,name,i,name,i)

        for i in range(3):
            for j in range(3):
                fcode += "new_r%d%d, r%d%d = eerot[%d*3+%d],\n"%(i,j,i,j,i,j)
        fcode += "new_px, new_py, new_pz, px = eetrans[0], py = eetrans[1], pz = eetrans[2];\n\n"
        
        # NOTE(review): rotsubs is built but not used in this method
        rotsubs = [(Symbol("r%d%d"%(i,j)),Symbol("new_r%d%d"%(i,j))) for i in range(3) for j in range(3)]
        rotsubs += [(Symbol("px"),Symbol("new_px")),(Symbol("py"),Symbol("new_py")),(Symbol("pz"),Symbol("new_pz"))]

        psymbols = ["new_px","new_py","new_pz"]
        # write the node.Tee entries into new_* temporaries, then copy them back
        for i in range(3):
            for j in range(3):
                fcode += self.writeEquations(lambda k: "new_r%d%d"%(i,j),node.Tee[4*i+j])
            fcode += self.writeEquations(lambda k: psymbols[i],node.Tee[4*i+3])
        for i in range(3):
            for j in range(3):
                fcode += "r%d%d = new_r%d%d; "%(i,j,i,j)
        fcode += "px = new_px; py = new_py; pz = new_pz;\n"

        fcode += self.generateTree(node.jointtree)
        code += self.indentCode(fcode,4) + "}\nreturn vsolutions.size()>0;\n}\n"
        return code
    def endChain(self, node):
        return ""

    def generateIKChainRotation3D(self, node):
        """Emit the solver for a rotation-only (IkType.Rotation3D) chain -- currently disabled.

        The leading ``assert False`` aborts every call while assertions are
        enabled.  Apart from the Matrix3x3 class and the ``Dim`` line, the text
        emitted by the body is still C++.

        node -- chain node; the code reads freejointvars, solvejointvars, Rfk,
                Ree and jointtree from it.
        Returns the generated source text as one string.
        """
        assert False
        # reset the per-chain bookkeeping
        self.freevars = []
        self.freevardependencies = []
        self.dictequations = []
        self.symbolgen = cse_main.numbered_symbols('x')
        
        code = "IKFAST_API int getNumFreeParameters() { return %d; }\n"%len(node.freejointvars)
        if len(node.freejointvars) == 0:
            code += "IKFAST_API const int* GetFreeIndices() { return NULL; }\n"
        else:
            # comma-separated list of the free joint indices
            code += "IKFAST_API const int* GetFreeIndices() { static const int freeindices[] = {"
            for i,freejointvar in enumerate(node.freejointvars):
                code += "%d"%(freejointvar[1])
                if i < len(node.freejointvars)-1:
                    code += ", "
            code += "}; return freeindices; }\n"
        code += "IKFAST_API int getNumJoints() { return %d; }\n\n"%(len(node.freejointvars)+len(node.solvejointvars))
        code += "IKFAST_API int getIKRealSize() { return sizeof(IKReal); }\n\n"
        code += 'IKFAST_API int getIKType() { return %d; }\n\n'%IkType.Rotation3D
        code += """
Public Class Matrix3x3
    Public m00 As %s
    Public m01 As %s
    Public m02 As %s
    Public m10 As %s
    Public m11 As %s
    Public m12 As %s
    Public m20 As %s
    Public m21 As %s
    Public m22 As %s
End Class
"""%(self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal)
        if node.Rfk:
            # NOTE(review): the emitted comment says "inverse" although this is the fk() function
            code += "/// solves the inverse kinematics equations.\n"
            code += "/// \\param pfree is an array specifying the free joints of the chain.\n"
            code += "IKFAST_API void fk(const IKReal* j, IKReal* eetrans, IKReal* eerot) {\n"
            allvars = node.solvejointvars + node.freejointvars
            # replace every joint symbol by j[index], then factor out common subexpressions
            subexprs,reduced_exprs=customcse (node.Rfk[0:3,0:3].subs([(v[0],Symbol('j[%d]'%v[1])) for v in allvars]))
            outputnames = ['eerot[0]','eerot[1]','eerot[2]','eerot[3]','eerot[4]','eerot[5]','eerot[6]','eerot[7]','eerot[8]']
            fcode = ''
            if len(subexprs) > 0:
                vars = [var for var,expr in subexprs]
                fcode = 'Dim ' + ','.join(str(var) for var,expr in subexprs) + ' As %s\n'%self.ikreal
                for var,expr in subexprs:
                    fcode += self.writeEquations(lambda k: str(var),collect(expr,vars))
            for i in range(len(outputnames)):
                fcode += self.writeEquations(lambda k: outputnames[i],reduced_exprs[i])
            code += self.indentCode(fcode,4)
            code += '}\n\n'
        code += "/// solves the inverse kinematics equations.\n"
        code += "/// \\param pfree is an array specifying the free joints of the chain.\n"
        code += "IKFAST_API bool ik(const IKReal* eetrans, const IKReal* eerot, const IKReal* pfree, std::vector<IKSolution>& vsolutions) {\n"
        code += "for(int dummyiter = 0; dummyiter < 1; ++dummyiter) {\n"
        fcode = "vsolutions.resize(0); vsolutions.reserve(8);\n"
        fcode += 'IKReal '
        
        # one declaration list: angle, cosine and sine per joint
        for var in node.solvejointvars:
            fcode += '%s, c%s, s%s,\n'%(var[0].name,var[0].name,var[0].name)
        for i in range(len(node.freejointvars)):
            name = node.freejointvars[i][0].name
            fcode += '%s=pfree[%d], c%s=cos(pfree[%d]), s%s=sin(pfree[%d]),\n'%(name,i,name,i,name,i)
        for i in range(3):
            for j in range(3):
                fcode += "new_r%d%d, r%d%d = eerot[%d*3+%d]"%(i,j,i,j,i,j)
                # the last entry terminates the declaration statement
                if i == 2 and j == 2:
                    fcode += ';\n\n'
                else:
                    fcode += ',\n'
        
        # NOTE(review): rotsubs is built but not used in this method
        rotsubs = [(Symbol("r%d%d"%(i,j)),Symbol("new_r%d%d"%(i,j))) for i in range(3) for j in range(3)]
        # write the node.Ree entries into new_* temporaries, then copy them back
        for i in range(3):
            for j in range(3):
                fcode += self.writeEquations(lambda k: "new_r%d%d"%(i,j),node.Ree[i,j])
        for i in range(3):
            for j in range(3):
                fcode += "r%d%d = new_r%d%d; "%(i,j,i,j)
        fcode += '\n'
        fcode += self.generateTree(node.jointtree)
        code += self.indentCode(fcode,4) + "}\nreturn vsolutions.size()>0;\n}\n"
        return code
    def endIKChainRotation3D(self, node):
        return ""

    def generateIKChainTranslation3D(self, node):
        """Emit the solver for a translation-only (IkType.Translation3D) chain -- currently disabled.

        The leading ``assert False`` aborts every call while assertions are
        enabled.  Apart from the Vector3 class and the ``Dim`` line, the text
        emitted by the body is still C++.

        node -- chain node; the code reads freejointvars, solvejointvars, Pfk,
                Pee and jointtree from it.
        Returns the generated source text as one string.
        """
        assert False
        # reset the per-chain bookkeeping
        self.freevars = []
        self.freevardependencies = []
        self.dictequations = []
        self.symbolgen = cse_main.numbered_symbols('x')
        
        code = "IKFAST_API int getNumFreeParameters() { return %d; }\n"%len(node.freejointvars)
        if len(node.freejointvars) == 0:
            code += "IKFAST_API const int* GetFreeIndices() { return NULL; }\n"
        else:
            # comma-separated list of the free joint indices
            code += "IKFAST_API const int* GetFreeIndices() { static const int freeindices[] = {"
            for i,freejointvar in enumerate(node.freejointvars):
                code += "%d"%(freejointvar[1])
                if i < len(node.freejointvars)-1:
                    code += ", "
            code += "}; return freeindices; }\n"
        code += "IKFAST_API int getNumJoints() { return %d; }\n\n"%(len(node.freejointvars)+len(node.solvejointvars))
        code += "IKFAST_API int getIKRealSize() { return sizeof(IKReal); }\n\n"
        code += 'IKFAST_API int getIKType() { return %d; }\n\n'%IkType.Translation3D
        code += """
Public Class Vector3
    Public x As %s
    Public y As %s
    Public z As %s
End Class
"""%(self.ikreal,self.ikreal,self.ikreal)
        if node.Pfk:
            # NOTE(review): the emitted comment says "inverse" although this is the fk() function
            code += "/// solves the inverse kinematics equations.\n"
            code += "/// \\param pfree is an array specifying the free joints of the chain.\n"
            code += "IKFAST_API void fk(const IKReal* j, IKReal* eetrans, IKReal* eerot) {\n"
            allvars = node.solvejointvars + node.freejointvars
            # replace every joint symbol by j[index], then factor out common subexpressions
            allsubs = [(v[0],Symbol('j[%d]'%v[1])) for v in allvars]
            eqs = []
            for eq in node.Pfk:
                eqs.append(eq.subs(allsubs))
            subexprs,reduced_exprs=customcse (eqs)
            outputnames = ['eetrans[0]','eetrans[1]','eetrans[2]']
            fcode = ''
            if len(subexprs) > 0:
                vars = [var for var,expr in subexprs]
                fcode = 'Dim ' + ','.join(str(var) for var,expr in subexprs) + ' As %s\n'%self.ikreal
                for var,expr in subexprs:
                    fcode += self.writeEquations(lambda k: str(var),collect(expr,vars))
            for i in range(len(outputnames)):
                fcode += self.writeEquations(lambda k: outputnames[i],reduced_exprs[i])
            code += self.indentCode(fcode,4)
            code += '}\n\n'
        code += "/// solves the inverse kinematics equations.\n"
        code += "/// \\param pfree is an array specifying the free joints of the chain.\n"
        code += "IKFAST_API bool ik(const IKReal* eetrans, const IKReal* eerot, const IKReal* pfree, std::vector<IKSolution>& vsolutions) {\n"
        code += "for(int dummyiter = 0; dummyiter < 1; ++dummyiter) {\n"
        fcode = "vsolutions.resize(0); vsolutions.reserve(8);\n"
        fcode += 'IKReal '
        
        # one declaration list: angle, cosine and sine per joint
        for var in node.solvejointvars:
            fcode += '%s, c%s, s%s,\n'%(var[0].name,var[0].name,var[0].name)
        for i in range(len(node.freejointvars)):
            name = node.freejointvars[i][0].name
            fcode += '%s=pfree[%d], c%s=cos(pfree[%d]), s%s=sin(pfree[%d]),\n'%(name,i,name,i,name,i)
        fcode += "new_px, new_py, new_pz, px = eetrans[0], py = eetrans[1], pz = eetrans[2];\n\n"
        # NOTE(review): rotsubs is built but not used in this method
        rotsubs = [(Symbol("px"),Symbol("new_px")),(Symbol("py"),Symbol("new_py")),(Symbol("pz"),Symbol("new_pz"))]
        psymbols = ["new_px","new_py","new_pz"]
        # write the node.Pee entries into new_* temporaries, then copy them back
        for i in range(3):
            fcode += self.writeEquations(lambda k: psymbols[i],node.Pee[i])
        fcode += "px = new_px; py = new_py; pz = new_pz;\n"
        fcode += self.generateTree(node.jointtree)
        code += self.indentCode(fcode,4) + "}\nreturn vsolutions.size()>0;\n}\n"
        return code
    def endIKChainTranslation3D(self, node):
        return ""

    def generateIKChainDirection3D(self, node):
        """Emit the solver for a direction-only (IkType.Direction3D) chain -- currently disabled.

        The leading ``assert False`` aborts every call while assertions are
        enabled.  Apart from the Vector3 class and the ``Dim`` line, the text
        emitted by the body is still C++.

        node -- chain node; the code reads freejointvars, solvejointvars, Dfk,
                Dee and jointtree from it.
        Returns the generated source text as one string.
        """
        assert False
        # reset the per-chain bookkeeping
        self.freevars = []
        self.freevardependencies = []
        self.dictequations = []
        self.symbolgen = cse_main.numbered_symbols('x')
        
        code = "IKFAST_API int getNumFreeParameters() { return %d; }\n"%len(node.freejointvars)
        if len(node.freejointvars) == 0:
            code += "IKFAST_API const int* GetFreeIndices() { return NULL; }\n"
        else:
            # comma-separated list of the free joint indices
            code += "IKFAST_API const int* GetFreeIndices() { static const int freeindices[] = {"
            for i,freejointvar in enumerate(node.freejointvars):
                code += "%d"%(freejointvar[1])
                if i < len(node.freejointvars)-1:
                    code += ", "
            code += "}; return freeindices; }\n"
        code += "IKFAST_API int getNumJoints() { return %d; }\n\n"%(len(node.freejointvars)+len(node.solvejointvars))
        code += "IKFAST_API int getIKRealSize() { return sizeof(IKReal); }\n\n"
        code += 'IKFAST_API int getIKType() { return %d; }\n\n'%IkType.Direction3D
        code += """
Public Class Vector3
    Public x As %s
    Public y As %s
    Public z As %s
End Class
"""%(self.ikreal,self.ikreal,self.ikreal)
        if node.Dfk:
            # NOTE(review): the emitted comment says "inverse" although this is the fk() function
            code += "/// solves the inverse kinematics equations.\n"
            code += "/// \\param pfree is an array specifying the free joints of the chain.\n"
            code += "IKFAST_API void fk(const IKReal* j, IKReal* eetrans, IKReal* eerot) {\n"
            allvars = node.solvejointvars + node.freejointvars
            # replace every joint symbol by j[index], then factor out common subexpressions
            allsubs = [(v[0],Symbol('j[%d]'%v[1])) for v in allvars]
            eqs = []
            for eq in node.Dfk:
                eqs.append(eq.subs(allsubs))
            subexprs,reduced_exprs=customcse (eqs)
            outputnames = ['eerot[0]','eerot[1]','eerot[2]']
            fcode = ''
            if len(subexprs) > 0:
                vars = [var for var,expr in subexprs]
                fcode = 'Dim ' + ','.join(str(var) for var,expr in subexprs) + ' As %s\n'%self.ikreal
                for var,expr in subexprs:
                    fcode += self.writeEquations(lambda k: str(var),collect(expr,vars))
            for i in range(len(outputnames)):
                fcode += self.writeEquations(lambda k: outputnames[i],reduced_exprs[i])
            code += self.indentCode(fcode,4)
            code += '}\n\n'
        code += "/// solves the inverse kinematics equations.\n"
        code += "/// \\param pfree is an array specifying the free joints of the chain.\n"
        code += "IKFAST_API bool ik(const IKReal* eetrans, const IKReal* eerot, const IKReal* pfree, std::vector<IKSolution>& vsolutions) {\n"
        code += "for(int dummyiter = 0; dummyiter < 1; ++dummyiter) {\n"
        fcode = "vsolutions.resize(0); vsolutions.reserve(8);\n"
        fcode += 'IKReal '
        
        # one declaration list: angle, cosine and sine per joint
        for var in node.solvejointvars:
            fcode += '%s, c%s, s%s,\n'%(var[0].name,var[0].name,var[0].name)
        for i in range(len(node.freejointvars)):
            name = node.freejointvars[i][0].name
            fcode += '%s=pfree[%d], c%s=cos(pfree[%d]), s%s=sin(pfree[%d]),\n'%(name,i,name,i,name,i)

        for i in range(3):
            fcode += "new_r0%d, r0%d = eerot[%d]"%(i,i,i)
            # the last entry terminates the declaration statement
            if i == 2:
                fcode += ';\n\n'
            else:
                fcode += ',\n'
        # NOTE(review): rotsubs is built but not used in this method
        rotsubs = [(Symbol("r%d%d"%(0,i)),Symbol("new_r%d%d"%(0,i))) for i in range(3)]

        # write the node.Dee entries into new_r0* temporaries, then copy them back
        for i in range(3):
            fcode += self.writeEquations(lambda k: "new_r%d%d"%(0,i),node.Dee[i])
        for i in range(3):
            fcode += "r0%d = new_r0%d; "%(i,i)

        fcode += self.generateTree(node.jointtree)
        code += self.indentCode(fcode,4) + "}\nreturn vsolutions.size()>0;\n}\n"
        return code
    def endIKChainDirection3D(self, node):
        return ''

    def generateIKChainRay4D(self, node):
        self.freevars = []
        self.freevardependencies = []
        self.dictequations = []
        self.symbolgen = cse_main.numbered_symbols('x')

        code = "Public Function getNumFreeParameters() As Integer\n    getNumFreeParameters = %d\nEnd Function\n"%len(node.freejointvars)
#         if len(node.freejointvars) == 0:
#             code += "IKFAST_API const int* GetFreeIndices() { return NULL; }\n"
#         else:
#             code += "IKFAST_API const int* GetFreeIndices() { static const int freeindices[] = {"
#             for i,freejointvar in enumerate(node.freejointvars):
#                 code += "%d"%(freejointvar[1])
#                 if i < len(node.freejointvars)-1:
#                     code += ", "
#             code += "}; return freeindices; }\n"
        code += 'Public Function getNumJoints() As Integer\n    getNumJoints = %d\nEnd Function\n'%(len(node.freejointvars)+len(node.solvejointvars))
        #code += "Public Function getIKRealSize() As Integer\n    return sizeof(IKReal); }\n\n"
        code += 'Public Function getIKType() As Integer\n    getIKType = %d\nEnd Function\n'%IkType.Ray4D
        
        if self.vb6:
            code += """
'' put in %s.cls file
'VERSION 1.0 CLASS
'BEGIN
'  MultiUse = -1  'True
'  Persistable = 0  'NotPersistable
'  DataBindingBehavior = 0  'vbNone
'  DataSourceBehavior  = 0  'vbNone
'  MTSTransactionMode  = 0  'NotAnMTSObject
'END
'Attribute VB_Name = "%s"
'Attribute VB_GlobalNameSpace = False
'Attribute VB_Creatable = True
'Attribute VB_PredeclaredId = False
'Attribute VB_Exposed = False
'Public x As Double
'Public y As Double
'Public z As Double
'Public i As Double
'Public j As Double
'Public k As Double

"""%(self.rayclassname,self.rayclassname)
        else:
            code += """
Public Class %s
    Public x As %s
    Public y As %s
    Public z As %s
    Public i As %s
    Public j As %s
    Public k As %s
End Class
"""%(self.rayclassname,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal,self.ikreal)
        if node.Dfk and node.Pfk:
            code += "' solves the forward kinematics equations.\n"
            code += "' pfree is an array specifying the free joints of the chain.\n"
            code += "Public Function fk(j() As %s) As %s\n    Dim output As New %s\n"%(self.ikreal,self.rayclassname,self.rayclassname)
            allvars = node.solvejointvars + node.freejointvars
            allsubs = [(v[0],Symbol('j(%d)'%v[1])) for v in allvars]
            eqs = []
            for eq in node.Pfk[0:3]:
                eqs.append(eq.subs(allsubs))
            for eq in node.Dfk[0:3]:
                eqs.append(eq.subs(allsubs))
            subexprs,reduced_exprs=customcse (eqs)
            outputnames = ['output.x','output.y','output.z','output.i','output.j','output.k']
            fcode = ''
            if len(subexprs) > 0:
                vars = [var for var,expr in subexprs]
                fcode = 'Dim ' + ','.join(str(var) for var,expr in subexprs) + ' As %s\n'%self.ikreal
                for var,expr in subexprs:
                    fcode += self.writeEquations(lambda k: str(var),collect(expr,vars))
            for i in range(len(outputnames)):
                fcode += self.writeEquations(lambda k: outputnames[i],reduced_exprs[i])
            code += self.indentCode(fcode,4)
            code += 'Set fk = output\nEnd Function\n'
        code += "' solves the inverse kinematics equations.\n"
        code += "' pfree is an array specifying the free joints of the chain.\n"
        code += "Public Function ik(inray As %s, pfree As %s, ByRef vsolutions() As IKSolution) As Boolean\n"%(self.rayclassname,self.ikreal)
        fcode = "Dim solution As IKSolution\n"
        if self.vb6:
            fcode += "Dim basesol as VARIABLE\n"
        fcode += "Dim numsolutions As Integer\nnumsolutions = 0\n"
        fcode += 'Dim dummyiter As Integer\n'
        fcode += 'Dim evalcond As %s\n'%(self.ikreal)
        fcode += self._startforloop('dummyiter',0,1)
        fcode += 'Dim '        
        for var in node.solvejointvars:
            fcode += '%s, c%s, s%s, numsolutions%s, '%(var[0].name,var[0].name,var[0].name,var[0].name)
        for i in range(len(node.freejointvars)):
            name = node.freejointvars[i][0].name
            fcode += '%s, c%s, s%s,'%(name,name,name)
        for i in range(3):
            fcode += "new_r0%d, r0%d,"%(i,i)
        fcode += "new_px, new_py, new_pz, px, py, pz As %s\n"%(self.ikreal)
        for i in range(len(node.freejointvars)):
            name = node.freejointvars[i][0].name
            fcode += '%s=pfree(%d)\nc%s=%s(pfree(%d))\ns%s=%s(pfree(%d))\n'%(name,i,name,self._cosname(),i,name,self._sinname(),i)
        fcode += "r00 = inray.i\nr01 = inray.j\nr02 = inray.k\n"
        fcode += "px = inray.x\npy = inray.y\npz= inray.z\n"
        rotsubs = [(Symbol("r%d%d"%(0,i)),Symbol("new_r%d%d"%(0,i))) for i in range(3)]
        rotsubs += [(Symbol("px"),Symbol("new_px")),(Symbol("py"),Symbol("new_py")),(Symbol("pz"),Symbol("new_pz"))]
        psymbols = ["new_px","new_py","new_pz"]
        for i in range(3):
            fcode += self.writeEquations(lambda k: "new_r%d%d"%(0,i),node.Dee[i])
            fcode += self.writeEquations(lambda k: psymbols[i],node.Pee[i])
        for i in range(3):
            fcode += "r0%d = new_r0%d\n"%(i,i)
        fcode += "\nDim new_pdotd As %s\nnew_pdotd = new_px*new_r00+new_py*new_r01+new_pz*new_r02\n"%(self.ikreal)
        fcode += "px = new_px-new_pdotd * new_r00\npy = new_py- new_pdotd * new_r01\npz = new_pz - new_pdotd * new_r02\n\n"
        for var in node.solvejointvars:
            name = var[0].name
            fcode += 'Dim %sarray(), c%sarray(), s%sarray(), %seval() As %s\n'%(name,name,name,name,self.ikreal)
            fcode += 'Dim i%s As Integer\n'%(name)
            fcode += 'Dim %smul As %s\n'%(name,self.ikreal)
            fcode += 'Dim %svalid() As %s\n'%(name,self.ikreal)
        self.dimvariables = []
        treecode = self.generateTree(node.jointtree)
        dimvariables = dict(map(lambda i: (i,1),self.dimvariables)).keys()
        if len(dimvariables) > 0:
            fcode += 'Dim %s As %s\n'%(','.join(name for name in dimvariables),self.ikreal)
        fcode += treecode
        code += self.indentCode(fcode,4) + self._endforloop()+ "ik = numsolutions>0\nEnd Function\n"
        return code
    def endIKChainRay4D(self, node):
        """End hook for a Ray4D chain node; it contributes no text.

        ``node`` is accepted for interface symmetry and is not inspected.
        Always returns the empty string.
        """
        return ''

    def generateSolution(self, node,declarearray=True,acceptfreevars=True):
        """Build the Visual Basic text that computes the candidate values of one joint.

        The joint is ``node.jointname``.  Exactly one of three sources is used,
        tested in this order:

        * ``node.jointeval``    - expressions for the joint value itself,
        * ``node.jointevalcos`` - expressions for the cosine of the joint
          (two angles are produced per expression),
        * ``node.jointevalsin`` - expressions for the sine of the joint
          (two angles are produced per expression).

        The generated code fills ``<name>array``, ``c<name>array``,
        ``s<name>array`` and ``<name>valid``.

        Parameters:
          node           - solution node; ``node.HasFreeVar`` is reset to False
                           here and set to True on the free-variable path.
          declarearray   - when False only the equation text is produced and
                           the return value is the tuple ``(eqcode, numsolutions)``.
          acceptfreevars - when True, an expression containing a symbol listed in
                           ``self.freevars`` is matched against ``a*freevar+b``;
                           on a match the code for ``<name>mul`` (a) and
                           ``<name>`` (b) is returned immediately as a string.

        Returns (declarearray True, no free variable): a string that re-dimensions
        the arrays, clears the valid flags, evaluates the equations, invalidates
        duplicate solutions and opens a loop over the solutions.  The ``If 1 Then``
        and the loop opened here are closed by the text from ``endSolution``.

        NOTE(review): the return type differs between the three paths (string,
        string, tuple); the only call in this chunk with declarearray=False
        also passes acceptfreevars=False.
        """
        code = ''
        numsolutions = 0
        eqcode = ''
        name = node.jointname
        node.HasFreeVar = False

        # Case 1: explicit expressions for the joint value.
        if node.jointeval is not None:
            numsolutions = len(node.jointeval)
            equations = []
            names = []
            for i,expr in enumerate(node.jointeval):
                if acceptfreevars:
                    m = None
                    for freevar in self.freevars:
                        if expr.has_any_symbols(Symbol(freevar)):
                            # has free variables, so have to look for a*freevar+b form
                            a = Wild('a',exclude=[Symbol(freevar)])
                            b = Wild('b',exclude=[Symbol(freevar)])
                            m = expr.match(a*Symbol(freevar)+b)
                            if m is not None:
                                # record the dependency; endSolution pops it again
                                self.freevardependencies.append((freevar,name))
                                assert(len(node.jointeval)==1)
                                code += self.writeEquations(lambda i: '%smul'%name, m[a])
                                code += self.writeEquations(lambda i: name, m[b])
                                node.HasFreeVar = True
                                return code
                            else:
                                print 'failed to extract free variable %s for %s from'%(freevar,node.jointname), expr
    #                             m = dict()
    #                             m[a] = Real(-1,30)
    #                             m[b] = Real(0,30)

                # three equations per solution: the value, its sine and its cosine
                equations.append(expr)
                names.append('%sarray(%d)'%(name,i))
                equations.append(sin(Symbol('%sarray(%d)'%(name,i))))
                names.append('s%sarray(%d)'%(name,i))
                equations.append(cos(Symbol('%sarray(%d)'%(name,i))))
                names.append('c%sarray(%d)'%(name,i))
            eqcode += self.writeEquations(lambda i: names[i], equations)

            if node.AddPiIfNegativeEq:
                # for every solution i also emit solution numsolutions+i, shifted by
                # -IKPI (value > 0) or +IKPI (otherwise) with negated sine and cosine
                for i in range(numsolutions):
                    eqcode += 'If %sarray(%d) > 0 Then\n    %sarray(%d) = %sarray(%d)-IKPI\nElse\n    %sarray(%d)=%sarray(%d)+IKPI\nEnd If\n'%(name,i,name,numsolutions+i,name,i,name,numsolutions+i,name,i)
                    eqcode += 's%sarray(%d) = -s%sarray(%d)\n'%(name,numsolutions+i,name,i)
                    eqcode += 'c%sarray(%d) = -c%sarray(%d)\n'%(name,numsolutions+i,name,i)
                numsolutions *= 2
            
            for i in range(numsolutions):
                if node.IsHinge:
                    # one shift by IK2PI when the value lies outside [-IKPI, IKPI]
                    eqcode += 'If %sarray(%d) > IKPI Then\n    %sarray(%d)=%sarray(%d)-IK2PI\nElseIf %sarray(%d) < -IKPI Then\n    %sarray(%d)=%sarray(%d)+IK2PI\nEnd If\n'%(name,i,name,i,name,i,name,i,name,i,name,i)
                eqcode += '%svalid(%d) = True\n'%(name,i)
        # Case 2: expressions for the cosine; slots 2*i and 2*i+1 hold the pair +/-acos.
        elif node.jointevalcos is not None:
            numsolutions = 2*len(node.jointevalcos)
            eqcode += self.writeEquations(lambda i: 'c%sarray(%d)'%(name,2*i),node.jointevalcos)
            for i in range(len(node.jointevalcos)):
                # accept cosines slightly outside [-1, 1] (tolerance 0.0001)
                eqcode += 'If c%sarray(%d) >= -1.0001 And c%sarray(%d) <= 1.0001 Then\n'%(name,2*i,name,2*i)
                eqcode += '    %svalid(%d) = True\n    %svalid(%d) = True\n'%(name,2*i,name,2*i+1)
                eqcode += '    %sarray(%d) = IKacos(c%sarray(%d))\n'%(name,2*i,name,2*i)
                eqcode += '    s%sarray(%d) = %s(%sarray(%d))\n'%(name,2*i,self._sinname(),name,2*i)
                # second solution
                eqcode += '    c%sarray(%d) = c%sarray(%d)\n'%(name,2*i+1,name,2*i)
                eqcode += '    %sarray(%d) = -%sarray(%d)\n'%(name,2*i+1,name,2*i)
                eqcode += '    s%sarray(%d) = -s%sarray(%d)\n'%(name,2*i+1,name,2*i)
                # not-a-number test differs between the VB6 and the non-VB6 output
                if self.vb6:
                    eqcode += 'ElseIf Not IsNumeric(c%sarray(%d)) Then\n'%(name,2*i)
                else:
                    eqcode += 'ElseIf Double.IsNaN(c%sarray(%d)) Then\n'%(name,2*i)
                eqcode += '    \' probably any value will work\n'
                eqcode += '    %svalid(%d) = True\n'%(name,2*i)
                eqcode += '    c%sarray(%d) = 1\n    s%sarray(%d) = 0\n    %sarray(%d) = 0\n'%(name,2*i,name,2*i,name,2*i)
                eqcode += 'End If\n'
        # Case 3: expressions for the sine; slots 2*i and 2*i+1 hold asin and its supplement.
        elif node.jointevalsin is not None:
            numsolutions = 2*len(node.jointevalsin)
            eqcode += self.writeEquations(lambda i: 's%sarray(%d)'%(name,2*i),node.jointevalsin)
            for i in range(len(node.jointevalsin)):
                # accept sines slightly outside [-1, 1] (tolerance 0.0001)
                eqcode += 'If s%sarray(%d) >= -1.0001 And s%sarray(%d) <= 1.0001 Then\n'%(name,2*i,name,2*i)
                eqcode += '    %svalid(%d) = True\n    %svalid(%d) = True\n'%(name,2*i,name,2*i+1)
                eqcode += '    %sarray(%d) = IKasin(s%sarray(%d))\n'%(name,2*i,name,2*i)
                eqcode += '    c%sarray(%d) = %s(%sarray(%d))\n'%(name,2*i,self._cosname(),name,2*i)
                # second solution
                eqcode += '    s%sarray(%d) = s%sarray(%d)\n'%(name,2*i+1,name,2*i)
                eqcode += '    If %sarray(%d) > 0 Then\n        %sarray(%d)=(IKPI-%sarray(%d))\n    Else\n        %sarray(%d)=(-IKPI-%sarray(%d))\n    End If\n'%(name,2*i,name,2*i+1,name,2*i,name,2*i+1,name,2*i)
                eqcode += '    c%sarray(%d) = -c%sarray(%d)\n'%(name,2*i+1,name,2*i)
                # not-a-number test differs between the VB6 and the non-VB6 output
                if self.vb6:
                    eqcode += 'ElseIf Not IsNumeric(s%sarray(%d)) Then\n'%(name,2*i)
                else:
                    eqcode += 'ElseIf Double.IsNaN(s%sarray(%d)) Then\n'%(name,2*i)
                eqcode += '    \' probably any value will work\n'
                eqcode += '    %svalid(%d) = True\n'%(name,2*i)
                eqcode += '    c%sarray(%d) = 1\n    s%sarray(%d) = 0\n    %sarray(%d) = 0\n'%(name,2*i,name,2*i,name,2*i)
                eqcode += 'End If\n'

        # the caller declares the arrays itself and only wants the equations
        if not declarearray:
            return eqcode,numsolutions

        code += 'If 1 Then\nReDim %sarray(0 To %d)\nReDim c%sarray(0 To %d)\nReDim s%sarray(0 To %d)\n'%(name,numsolutions-1,name,numsolutions-1,name,numsolutions-1)
        code += 'ReDim %svalid(0 To %d)\n'%(name,numsolutions-1)
        for i in range(numsolutions):
            code += '%svalid(%d) = False\n'%(name,i)
        code += eqcode
        # invalidate the later one of any two valid solutions whose sine and cosine coincide
        for i,j in combinations(range(numsolutions),2):
            code += 'If %svalid(%d) And %svalid(%d) And %s(c%sarray(%d)-c%sarray(%d)) < 0.0001 And %s(s%sarray(%d)-s%sarray(%d)) < 0.0001 Then\n    %svalid(%d)=False\nEnd If\n'%(name,i,name,j,self._absname(),name,i,name,j,self._absname(),name,i,name,j,name,j)
        # open the loop over the solutions; invalid entries are skipped
        code += self._startforloop('i%s'%name,0,numsolutions)
        code += 'If Not %svalid(i%s) Then\n%s\nEnd If\n'%(name,name,self.indentCode(self._continueforloop(),4))
        code += '%s = %sarray(i%s)\nc%s = c%sarray(i%s)\ns%s = s%sarray(i%s)\n'%(name,name,name,name,name,name,name,name,name)
        return code

    def endSolution(self, node):
        """Return the text that closes what generateSolution opened for ``node``.

        When the node was solved through a free variable (``node.HasFreeVar``),
        nothing was opened: the most recent entry of ``self.freevardependencies``
        is removed and the empty string is returned.  Otherwise the end-of-loop
        text followed by ``End If`` is returned.
        """
        if not node.HasFreeVar:
            return self._endforloop()+'End If\n'
        self.freevardependencies.pop()
        return ''

    def generateConditionedSolution(self, node):
        """Build VB text that picks one of several candidate solutions for a joint.

        Every entry of ``node.solversolutions`` solves the same joint (asserted)
        and carries ``checkforzeros`` expressions.  For each candidate the
        generated code evaluates those expressions into ``<name>eval`` and tests
        every ``Abs(<name>eval(i))`` against ``node.thresh`` (``<=`` when the
        candidate's ``FeasibleIsZeros`` is set, ``>`` otherwise).  When the test
        passes, the candidate's equations (from ``generateSolution`` with
        declarearray=False and acceptfreevars=False) run and
        ``numsolutions<name>`` is set; otherwise the next candidate is tried in
        the ``Else`` part.  If no candidate applies the generated code executes
        the continue statement.

        Returns a string that leaves an ``If 1 Then`` and a loop over the
        solutions open; ``endConditionedSolution`` supplies the closing text.
        ``self.dictequations`` is restored to its value on entry.
        """
        name=node.solversolutions[0].jointname
        assert all([name == s.jointname for s in node.solversolutions])
        origequations = copy.copy(self.dictequations)
        # NOTE(review): maxchecks is computed but not used below
        maxchecks = max([len(s.checkforzeros) for s in node.solversolutions])
        allnumsolutions = 0
        checkcode = ''
        for solversolution in node.solversolutions:
            assert len(solversolution.checkforzeros) > 0
            # each candidate starts from the equations known on entry
            self.dictequations = copy.copy(origequations)
            checkcode += 'ReDim %seval(0 To %d)\n'%(name,len(solversolution.checkforzeros)-1)
            checkcode += self.writeEquations(lambda i: '%seval(%d)'%(name,i),solversolution.checkforzeros)
            checkcode += 'If '
            for i in range(len(solversolution.checkforzeros)):
                if i != 0:
                    checkcode += ' And '
                checkcode += '%s(%seval(%d)) %s %f '%('Abs' if self.vb6 else 'Math.Abs',name,i,'<=' if solversolution.FeasibleIsZeros else '>',node.thresh)
            checkcode += ' Then\n'
            scode,numsolutions = self.generateSolution(solversolution,declarearray=False,acceptfreevars=False)
            scode += 'numsolutions%s = %d\n'%(name,numsolutions)
            # the arrays are sized for the candidate with the most solutions
            allnumsolutions = max(allnumsolutions,numsolutions)
            checkcode += self.indentCode(scode,4)
            # the Else stays open: the next candidate (or the final continue) nests inside it
            checkcode += '\nElse\n'
        checkcode += self._continueforloop()  # if got here, then current solution branch is not good, so skip
        # one End If per candidate closes the nested If/Else chain
        checkcode += 'End If\n'*len(node.solversolutions)
        checkcode += 'If numsolutions%s = 0 Then\n    '%name+self._continueforloop()+'End If\n'

        self.dictequations = origequations
        code = 'If 1 Then\nReDim %sarray(0 To %d)\nReDim c%sarray(0 To %d)\nReDim s%sarray(0 To %d)\n'%(name,allnumsolutions-1,name,allnumsolutions-1,name,allnumsolutions-1)
        code += 'ReDim %svalid(0 To %d)\n'%(name,allnumsolutions-1)
        code += 'numsolutions%s = 0\n'%name
        for i in range(allnumsolutions):
            code += '%svalid(%d) = False\n'%(name,i)
        code += self.indentCode(checkcode,4)
        # invalidate the later one of any two valid solutions whose sine and cosine coincide
        for i,j in combinations(range(allnumsolutions),2):
            code += 'If %svalid(%d) And %svalid(%d) And %s(c%sarray(%d)-c%sarray(%d)) < 0.0001 And %s(s%sarray(%d)-s%sarray(%d)) < 0.0001 Then\n    %svalid(%d)=False\nEnd If\n'%(name,i,name,j,self._absname(),name,i,name,j,self._absname(),name,i,name,j,name,j)
        # the loop bound is the run-time variable numsolutions<name>, not a constant
        code += self._startforloop('i%s'%name,0,'numsolutions%s'%name)
        code += 'If Not %svalid(i%s) Then\n%s\nEnd If\n'%(name,name,self.indentCode(self._continueforloop(),4))
        code += '%s = %sarray(i%s)\nc%s = c%sarray(i%s)\ns%s = s%sarray(i%s)\n'%(name,name,name,name,name,name,name,name,name)
        return code

    def endConditionedSolution(self, node):
        """Return the closing text for a conditioned solution.

        The result is the end-of-loop text followed by ``End If``; ``node`` is
        not inspected.
        """
        loopend = self._endforloop()
        return loopend+'End If\n'

    def generateBranch(self, node):
        """Build VB text that dispatches on the value of ``node.jointeval``.

        The expression is evaluated into the VB variable ``evalcond``.  Each
        entry of ``node.jointbranches`` is a pair ``(value, subtree)``:

        * ``value`` is None - the subtree is emitted unconditionally inside
          ``If 1 Then ... End If``;
        * otherwise the subtree is emitted when ``evalcond`` lies within
          0.00001 of ``value``, and the following branches are nested in the
          ``Else`` part.

        Every subtree is generated from a copy of the equations known on entry,
        and ``self.dictequations`` is restored before returning.  Returns the
        generated text, with all opened ``If`` blocks closed.
        """
        origequations = copy.copy(self.dictequations)
        code = 'If 1 Then\n'
        code += self.writeEquations(lambda x: 'evalcond',[node.jointeval])
        numif = 1  # number of If blocks still open, starting with the outer 'If 1 Then'
        for branch in node.jointbranches:
            self.dictequations = copy.copy(origequations)
            branchcode = self.indentCode(self.generateTree(branch[1]),4)
            if branch[0] is None:
                code += 'If 1 Then\n' + branchcode + 'End If\n'
            else:
                # the format string has exactly two %f fields: the lower and upper bound
                code += 'If evalcond >= %f And evalcond <= %f Then\n'%(branch[0]-0.00001,branch[0]+0.00001)
                code += branchcode + 'Else\n'
                numif += 1
        code += 'End If\n'*numif
        self.dictequations = origequations
        return code
    def endBranch(self, node):
        """End hook for a branch node; generateBranch already closes everything it opens.

        Always returns the empty string; ``node`` is not inspected.
        """
        return ''
    def generateBranchConds(self, node):
        """Build VB text for a chain of branches, each guarded by its own condition.

        Each entry of ``node.jointbranches`` is indexed as ``(condition, subtree)``.
        A condition of None yields ``If 1 Then``; otherwise the condition
        equations are written into ``evalcond`` and the subtree runs when its
        absolute value is below 0.00001.  Every branch ends in ``Else`` so the
        next branch nests inside it; the trailing ``End If`` lines close all of
        them plus the outer ``If 1 Then``.  ``self.dictequations`` is restored
        before returning.
        """
        savedequations = copy.copy(self.dictequations)
        pieces = ['If 1 Then\n']
        for branch in node.jointbranches:
            condition = branch[0]
            if condition is not None:
                # the condition is written before the equation table is reset below
                header = self.writeEquations(lambda x: 'evalcond',condition)
                header += 'If %s(evalcond) < 0.00001 Then\n'%('Abs' if self.vb6 else 'Math.Abs')
            else:
                header = 'If 1 Then\n'
            self.dictequations = copy.copy(savedequations)
            body = self.generateTree(branch[1])
            pieces.append(self.indentCode(header+body,4)+'Else\n')
        pieces.append('End If\n'*(len(node.jointbranches)+1))
        self.dictequations = savedequations
        return ''.join(pieces)
    def endBranchConds(self, node):
        """End hook for a conditioned-branch node; nothing is left to close.

        Always returns the empty string; ``node`` is not inspected.
        """
        return ''
    def generateCheckZeros(self, node):
        """Build VB text that selects between two subtrees by testing expressions against zero.

        ``node.jointcheckeqs`` are evaluated into ``<name>eval``.  When there is
        at least one, the generated ``If`` compares each absolute value with
        ``node.thresh`` (joined with ``Or`` when ``node.anycondition`` is set,
        ``And`` otherwise): ``node.zerobranch`` runs when the test holds and
        ``node.nonzerobranch`` otherwise.  With no check equations only the
        non-zero branch is emitted.  Both subtrees start from the equations
        known on entry, and ``self.dictequations`` is restored before returning.
        The whole text is wrapped in ``If 1 Then ... End If``.
        """
        savedequations = copy.copy(self.dictequations)
        name = node.jointname
        checkeqs = node.jointcheckeqs
        numchecks = len(checkeqs)
        parts = ['ReDim %seval(0 To %d)\n'%(name,numchecks-1)]
        parts.append(self.writeEquations(lambda i: '%seval(%d)'%(name,i),checkeqs))
        guarded = numchecks > 0
        if guarded:
            tests = ['%s(%seval(%d)) < %f '%('Abs' if self.vb6 else 'Math.Abs',name,k,node.thresh) for k in range(numchecks)]
            connective = ' And '
            # the connective only matters (and is only looked up) with two or more tests
            if numchecks > 1 and node.anycondition:
                connective = ' Or '
            parts.append('If '+connective.join(tests)+' Then\n')
            self.dictequations = copy.copy(savedequations)
            parts.append(self.indentCode(self.generateTree(node.zerobranch),4))
            parts.append('\nElse\n')
        self.dictequations = copy.copy(savedequations)
        parts.append(self.indentCode(self.generateTree(node.nonzerobranch),4))
        if guarded:
            parts.append('\nEnd If\n')
        self.dictequations = savedequations
        return 'If 1 Then\n' + self.indentCode(''.join(parts),4) + 'End If\n'
    def endCheckZeros(self, node):
        """End hook for a check-zeros node; generateCheckZeros closes its own blocks.

        Always returns the empty string; ``node`` is not inspected.
        """
        return ''
    def generateFreeParameter(self, node):
        """Register ``node.jointname`` as a free variable and generate its subtree.

        The name is pushed onto ``self.freevars`` and the pair
        ``(name, name)`` onto ``self.freevardependencies`` (both are popped by
        ``endFreeParameter``).  The returned text sets ``<name>mul`` to 1 and
        ``<name>`` to 0, followed by the code for ``node.jointtree``.
        """
        jointname = node.jointname
        self.freevars.append(jointname)
        self.freevardependencies.append((jointname,jointname))
        header = '%smul = 1\n%s=0\n'%(jointname,jointname)
        return header+self.generateTree(node.jointtree)
    def endFreeParameter(self, node):
        """Drop the most recent free variable and its dependency entry.

        Pops ``self.freevars`` first, then ``self.freevardependencies``, and
        returns the empty string; ``node`` is not inspected.
        """
        for stack in (self.freevars, self.freevardependencies):
            stack.pop()
        return ''
    def generateSetJoint(self, node):
        """Build VB text assigning a fixed value to a joint.

        Opens ``If 1 Then`` and assigns ``node.jointvalue`` to the joint
        variable, its sine to ``s<name>`` and its cosine to ``c<name>``, each
        formatted with ``%f``.  The block is closed by ``endSetJoint``.
        """
        name = node.jointname
        value = node.jointvalue
        lines = ['If 1 Then\n',
                 '    %s = %f\n'%(name,value),
                 '    s%s = %f\n'%(name,sin(value)),
                 '    c%s = %f\n'%(name,cos(value))]
        return ''.join(lines)
    def endSetJoint(self, node):
        """Return the ``End If`` line that closes the block opened by generateSetJoint."""
        closing = 'End If\n'
        return closing
    def generateBreak(self,node):
        """Return the continue statement of the enclosing generated loop; ``node`` is not inspected."""
        skipstatement = self._continueforloop()
        return skipstatement
    def endBreak(self,node):
        """End hook for a break node; always returns the empty string."""
        return ''
    def generateRotation(self, node):
        """Write the nine entries ``node.T[i,j]`` (i, j in 0..2) into ``new_r<i><j>``.

        The entries are written row by row through ``writeEquations``; the code
        for ``node.jointtree`` follows.  Returns the combined text.
        """
        cells = [(i,j) for i in range(3) for j in range(3)]
        listequations = [node.T[i,j] for i,j in cells]
        names = [Symbol('new_r%d%d'%(i,j)) for i,j in cells]
        return self.writeEquations(lambda k: names[k],listequations)+self.generateTree(node.jointtree)
    def endRotation(self, node):
        """End hook for a rotation node; always returns the empty string."""
        return ''
    def generateDirection(self, node):
        """Write the three entries ``node.D[i]`` into ``new_r00``, ``new_r01``, ``new_r02``.

        The entries are written through ``writeEquations``; the code for
        ``node.jointtree`` follows.  Returns the combined text.
        """
        indices = range(3)
        listequations = [node.D[i] for i in indices]
        names = [Symbol('new_r%d%d'%(0,i)) for i in indices]
        return self.writeEquations(lambda k: names[k],listequations)+self.generateTree(node.jointtree)
    def endDirection(self, node):
        """End hook for a direction node; always returns the empty string."""
        return ''
    def startSolution(self,numjointvars,numfreevars):
        """Return the VB text that creates a new ``solution`` and sizes its arrays.

        ``basesol`` is re-dimensioned ``0 To numjointvars`` and ``vfree``
        ``0 To numfreevars``.  With ``self.vb6`` set the ``ReDim`` lines carry
        an explicit ``As`` type; otherwise the type is omitted.
        """
        if self.vb6:
            basesoltemplate = 'ReDim solution.basesol(0 To %d) As IKBaseSolution\n'
            vfreetemplate = 'ReDim solution.vfree(0 To %d) As Integer\n'
        else:
            basesoltemplate = 'ReDim solution.basesol(0 To %d)\n'
            vfreetemplate = 'ReDim solution.vfree(0 To %d)\n'
        return 'solution = New IKSolution\n'+basesoltemplate%numjointvars+vfreetemplate%numfreevars
    def generateStoreSolution(self, node):
        """Build VB text that records the current joint values as one solution.

        For every variable in ``node.alljointvars`` the offset is stored; when
        ``self.freevardependencies`` has an entry whose second item is the
        variable's name, the multiplier ``<name>mul`` and the index of the
        first matching free variable in ``self.freevars`` are stored as well.
        Each free variable's index within ``node.alljointvars`` goes into
        ``vfree``.  Finally the solution is appended to ``vsolutions`` and
        ``numsolutions`` is incremented.
        """
        jointvars = node.alljointvars
        pieces = [self.startSolution(len(jointvars), len(self.freevars))]
        for i,var in enumerate(jointvars):
            pieces.append('solution.basesol(%d).foffset = %s\n'%(i,var))
            dependencies = [dep for dep in self.freevardependencies if dep[1]==var.name]
            if dependencies:
                freevarname = dependencies[0][0]
                ifreevar = [j for j,candidate in enumerate(self.freevars) if freevarname==candidate]
                pieces.append('solution.basesol(%d).fmul = %smul\n'%(i,var.name))
                pieces.append('solution.basesol(%d).freeind = %d\n'%(i,ifreevar[0]))
        for i,varname in enumerate(self.freevars):
            ind = [j for j,jointvar in enumerate(jointvars) if varname==jointvar.name]
            pieces.append('solution.vfree(%d) = %d\n'%(i,ind[0]))
        pieces.append('If numsolutions > 0 Then\n    ReDim Preserve vsolutions(0 To numsolutions) As IKSolution\nElse\n    ReDim vsolutions(0 To 0) As IKSolution\nEnd If\nSet vsolutions(numsolutions) = solution\nnumsolutions = numsolutions + 1\n')
        return ''.join(pieces)
    def endStoreSolution(self, node):
        """End hook for a store-solution node; always returns the empty string."""
        return ''
    def generateSequence(self, node):
        """Concatenate the code generated for each tree in ``node.jointtrees``, in order."""
        return ''.join([self.generateTree(tree) for tree in node.jointtrees])
    def endSequence(self, node):
        """End hook for a sequence node; always returns the empty string."""
        return ''
    def generateTree(self,tree):
        """Generate code for a list of nodes and close them in reverse order.

        Calls ``generate(self)`` on every node front to back, then ``end(self)``
        on every node back to front, and returns the concatenated text.
        """
        opening = [n.generate(self) for n in tree]
        closing = [n.end(self) for n in reversed(tree)]
        return ''.join(opening+closing)
    def writeEquations(self, varnamefn, exprs):
        """Emit VB assignments for ``exprs`` after common-subexpression elimination.

        ``customcse`` splits the expressions into replacements (temporary
        symbols) and reduced expressions.  A replacement whose expression equals
        one already recorded in ``self.dictequations`` is emitted as a copy of
        the recorded symbol; otherwise its expression is written out and
        recorded.  Every temporary's name is added to ``self.dimvariables``.
        The i-th reduced expression is then assigned to ``varnamefn(i)``.
        Guard statements returned by ``writeExprCode`` precede the assignment
        they belong to.  Returns the generated text.
        """
        replacements,reduced_exprs = customcse(exprs,symbols=self.symbolgen)
        lines = []
        for rep in replacements:
            known = [eq for eq in self.dictequations if rep[1]-eq[1]==0]
            if known:
                # same expression was computed before: reuse that variable
                previous = known[0][0]
                self.dictequations.append((rep[0],previous))
                self.dimvariables.append(str(rep[0]))
                lines.append('%s = %s\n'%(rep[0],previous))
            else:
                self.dictequations.append(rep)
                exprcode,guardcode = self.writeExprCode(rep[1])
                self.dimvariables.append(str(rep[0]))
                lines.append(guardcode+'%s = %s\n'%(rep[0],exprcode))

        for i,rexpr in enumerate(reduced_exprs):
            exprcode,guardcode = self.writeExprCode(rexpr)
            lines.append(guardcode+'%s=%s\n'%(varnamefn(i), exprcode))
        return ''.join(lines)

    def writeExprCode(self, expr):
        """Translate a sympy expression into Visual Basic source text.

        Returns the pair ``(code, sepcode)``.  ``code`` is the VB expression;
        ``sepcode`` holds complete guard statements (each ending in a newline)
        that must be emitted before the statement using ``code``.  The guards
        execute the continue statement when an ``acos``/``asin`` argument lies
        outside [-1.0001, 1.0001] or a square-root argument is below -0.00001.

        Handled forms, tested in this order:
          * functions: abs, acos, asin, atan2, sin, cos get dedicated VB names;
            any other function is emitted under its own name with
            comma-separated arguments;
          * products and sums: every operand is parenthesised;
          * powers: positive integer exponents become repeated multiplication,
            0.5 becomes ``IKsqrt``, -1 becomes ``IKdiv``, other negative real
            exponents become ``IKnegpow``, everything else uses ``^``;
          * anything else is printed with ``self.strprinter`` after ``evalf()``.
        """
        if expr.is_Function:
            func = expr.func
            if func == abs:
                argcode,sepcode = self.writeExprCode(expr.args[0])
                code = ('Abs(' if self.vb6 else 'Math.Abs(')+argcode
            elif func == acos or func == asin:
                argcode,sepcode = self.writeExprCode(expr.args[0])
                code = ('IKacos(' if func == acos else 'IKasin(')+argcode
                # skip this solution when the argument is out of the function's domain
                sepcode += 'If (%s) < -1.0001 Or (%s) > 1.0001 Then\n%s\nEnd If\n'%(argcode,argcode,self.indentCode(self._continueforloop(),4))
            elif func == atan2:
                # check for divides by 0 in arguments, this could give two possible solutions?!?
                # if common arguments is nan! solution is lost!
                argcode,sepcode = self.writeExprCode(expr.args[0])
                argcode2,sepcode2 = self.writeExprCode(expr.args[1])
                code = 'IKatan2('+argcode+', '+argcode2
                sepcode += sepcode2
            elif func == sin:
                # NOTE: a disabled variant reused the s<joint> variable for joint symbols
                argcode,sepcode = self.writeExprCode(expr.args[0])
                code = '%s('%(self._sinname())+argcode
            elif func == cos:
                # NOTE: a disabled variant reused the c<joint> variable for joint symbols
                argcode,sepcode = self.writeExprCode(expr.args[0])
                code = '%s('%(self._cosname())+argcode
            else:
                # NOTE(review): a second 'atan2' test that emitted 'IKfmod(' used to sit
                # before this fallback; it could never be reached and was removed.  It
                # was presumably meant for a modulo function - confirm before adding one.
                argcodes = []
                sepcode = ''
                for arg in expr.args:
                    argcode,argsepcode = self.writeExprCode(arg)
                    argcodes.append(argcode)
                    sepcode += argsepcode
                # join by position so that repeated arguments keep their separator
                code = func.__name__+'('+','.join(argcodes)
            return code + ')',sepcode
        elif expr.is_Mul or expr.is_Add:
            operator = '*' if expr.is_Mul else '+'
            operands = []
            sepcode = ''
            for arg in expr.args:
                argcode,argsepcode = self.writeExprCode(arg)
                operands.append('('+argcode+')')
                sepcode += argsepcode
            return '('+operator.join(operands)+')',sepcode
        elif expr.is_Pow:
            exprbase,sepcode = self.writeExprCode(expr.base)
            if expr.exp.is_real:
                if expr.exp.is_integer and expr.exp.evalf() > 0:
                    # x**n -> (x)*(x)*...*(x); int() gives a plain repeat count
                    return '*'.join(['('+exprbase+')']*int(expr.exp)),sepcode
                elif expr.exp-0.5 == 0:
                    # the guard must end with a newline so the next statement starts on its own line
                    sepcode += 'If (%s) < -0.00001 Then\n%s\nEnd If\n'%(exprbase,self.indentCode(self._continueforloop(),4))
                    return 'IKsqrt('+exprbase+')',sepcode
                elif expr.exp < 0:
                    # check if exprbase is 0
                    if expr.exp+1 == 0:
                        return 'IKdiv('+exprbase+')',sepcode
                    return 'IKnegpow('+exprbase+',' + str(expr.exp.evalf()) + ')',sepcode
            exprexp,sepcode2 = self.writeExprCode(expr.exp)
            sepcode += sepcode2
            return '((' + exprbase + ')^(' + exprexp + '))',sepcode

        return self.strprinter.doprint(expr.evalf()),''

    def indentCode(self, code, numspaces):
        """Indent ``code`` by ``numspaces`` spaces.

        The padding is placed at the very start and directly after every
        newline character, including a trailing one, so text ending in a
        newline ends in padding.
        """
        padding = ' '*numspaces
        return padding+code.replace('\n','\n'+padding)

class CodeGeneratorVB6(CodeGenerator):
    """Code generator variant with the ``vb6`` flag switched on."""
    def __init__(self):
        """Set ``self.vb6`` to True.

        NOTE(review): the base class initialiser is not invoked here - confirm
        that no other instance state is expected.
        """
        self.vb6 = True

class CodeGeneratorVB6Special(CodeGenerator):
    """VB6 code generator variant that sizes and fills solutions through helper calls.

    Differs from the base generator in how a solution object is created
    (``rdimsol`` / ``rdimfree``) and in filling a ``basesol`` value that is
    then assigned into ``solution.basesol(i)``.
    """
    def __init__(self):
        """Switch on VB6 output and use 'Vector' as the ray class name.

        NOTE(review): the base class initialiser is not invoked here - confirm
        that no other instance state is expected.
        """
        self.vb6 = True
        self.rayclassname = 'Vector'
    def startSolution(self,numjointvars,numfreevars):
        """Return VB text creating ``solution`` and sizing it for the given counts."""
        template = 'Set solution = New IKSolution\nsolution.rdimsol(%d)\nsolution.rdimfree(%d)\n'
        return template%(numjointvars,numfreevars)
    def generateStoreSolution(self, node):
        """Build VB text that records the current joint values as one solution.

        For every variable in ``node.alljointvars`` the offset is written to
        ``basesol``; when ``self.freevardependencies`` has an entry whose second
        item is the variable's name, the multiplier and the index of the first
        matching free variable are written too.  ``basesol`` is then assigned to
        ``solution.basesol(i)``.  Each free variable's index within
        ``node.alljointvars`` goes into ``vfree``; finally the solution is
        appended to ``vsolutions`` and ``numsolutions`` is incremented.
        """
        jointvars = node.alljointvars
        pieces = [self.startSolution(len(jointvars), len(self.freevars))]
        for i,var in enumerate(jointvars):
            pieces.append('basesol.foffset = %s\n'%(var))
            dependencies = [dep for dep in self.freevardependencies if dep[1]==var.name]
            if dependencies:
                freevarname = dependencies[0][0]
                ifreevar = [j for j,candidate in enumerate(self.freevars) if freevarname==candidate]
                pieces.append('basesol.fmul = %smul\n'%(var.name))
                pieces.append('basesol.freeind = %d\n'%(ifreevar[0]))
            pieces.append("solution.basesol(%d) = basesol\n"%(i))
        for i,varname in enumerate(self.freevars):
            ind = [j for j,jointvar in enumerate(jointvars) if varname==jointvar.name]
            pieces.append('solution.vfree(%d) = %d\n'%(i,ind[0]))
        pieces.append('If numsolutions > 0 Then\n    ReDim Preserve vsolutions(0 To numsolutions) As IKSolution\nElse\n    ReDim vsolutions(0 To 0) As IKSolution\nEnd If\nSet vsolutions(numsolutions) = solution\nnumsolutions = numsolutions + 1\n')
        return ''.join(pieces)
